# Copyright (c) OpenMMLab. All rights reserved.
import types

from xtuner._lite import get_logger
from xtuner._lite.accelerate.dispatches.huggingface import DISPATCH_MAP, _dispatch_forward_fn, _dispatch_rms_norm_forward


# Module-level logger obtained from xtuner's ``get_logger`` helper.
logger = get_logger()


def dispatch_internlm3_moe_varlen_attn_forward(module):
    """Install the varlen attention forward on an InternLM3 MoE attention.

    Args:
        module: attention module instance whose class name must be
            ``'InternLM3MoEFlashAttention2'``.

    Returns:
        The ``__name__`` of the installed forward function, or ``None`` when
        ``varlen_attn_is_available()`` is false and the module was left as is.
    """
    assert module.__class__.__name__ == 'InternLM3MoEFlashAttention2'
    from .internlm3_moe import \
        internlm3_moe_varlen_attn_forward as varlen_forward
    from xtuner._lite.accelerate import varlen_attn_is_available

    if not varlen_attn_is_available():
        # Nothing patched; the caller treats ``None`` as "not dispatched".
        return None

    _dispatch_forward_fn(module, varlen_forward)
    return varlen_forward.__name__


def dispatch_qwen2_moe_varlen_attn_forward(module):
    """Install the varlen attention forward on a Qwen2-MoE attention module.

    Args:
        module: attention module instance whose class name must be
            ``'Qwen2MoeFlashAttention2'``.

    Returns:
        The ``__name__`` of the installed forward function.
    """
    # NOTE(review): unlike the InternLM3 dispatcher, this one does not consult
    # ``varlen_attn_is_available()`` and always patches -- confirm intended.
    assert module.__class__.__name__ == 'Qwen2MoeFlashAttention2'
    from .qwen_moe import qwen2_varlen_attn_forward as varlen_forward

    _dispatch_forward_fn(module, varlen_forward)
    return varlen_forward.__name__


# Register the MoE module classes handled by this file in the shared
# class-name -> dispatcher registry imported from the huggingface dispatches.
DISPATCH_MAP.update({
    'InternLM3MoEFlashAttention2': dispatch_internlm3_moe_varlen_attn_forward,
    'InternLM3MoERMSNorm': _dispatch_rms_norm_forward,
    'Qwen2MoeFlashAttention2': dispatch_qwen2_moe_varlen_attn_forward,
    'Qwen2MoeRMSNorm': _dispatch_rms_norm_forward,
})


def dispatch_hf_code(model, exclude_cls=()):
    """Patch supported submodules of ``model`` in place via ``DISPATCH_MAP``.

    Every module yielded by ``model.named_modules()`` whose class name is a
    key of ``DISPATCH_MAP`` (and is not excluded) is passed to the registered
    dispatcher. A non-``None`` return value from the dispatcher is logged as
    the name of the forward function that was installed.

    Args:
        model: object exposing ``named_modules()`` yielding ``(name, module)``
            pairs (presumably a ``torch.nn.Module``).
        exclude_cls: iterable of class names to skip. A single class name
            may be passed as a plain string; ``None`` means "exclude nothing".
    """
    if exclude_cls is None:
        exclude_cls = ()
    elif isinstance(exclude_cls, str):
        # A bare string would make ``in`` a substring test, so e.g.
        # 'Qwen2MoeRMSNorm' would also exclude a class named 'RMSNorm'.
        exclude_cls = (exclude_cls, )
    excluded = frozenset(exclude_cls)

    for name, module in model.named_modules():
        cls_name = module.__class__.__name__
        if cls_name in excluded:
            continue

        dispatch_fn = DISPATCH_MAP.get(cls_name)
        if dispatch_fn is None:
            continue

        dispatched = dispatch_fn(module)
        # Dispatchers return None when they decided not to patch the module.
        if dispatched is not None:
            logger.info(
                f'Dispatch {name}({cls_name}) forward to `{dispatched}`')
